{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Utilities.io import DataLoader\n",
    "from Utilities.lossMetric import *\n",
    "from Utilities.trainVal import MinMaxGame\n",
    "from Models.RRDBNet import RRDBNet\n",
    "from Models.GAN import Discriminator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load in the training dataset\n",
    "I used the Chinese City Parking Dataset for this project. Please download the dataset from https://github.com/detectRecog/CCPD  \n",
    "Before loading the dataset, it is critical that you run the preprocessing script (preprocess.py) first!!!   \n",
    "` \n",
    "python preprocess.py 5 PATH_TO_UNZIPPED_DATA PATH_TO_OUTPUT_DIR\n",
    "`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import glob\n",
    "PATH = 'PATH_TO_OUTPUT_DIR/192_96' # only use images with shape 192 by 96 for training\n",
    "files = glob.glob(PATH + '/*.jpg') * 3  # data augmentation, same image with different brightness and contrast\n",
    "np.random.shuffle(files)\n",
    "train, val = files[:int(len(files)*0.8)], files[int(len(files)*0.8):]\n",
    "loader = DataLoader()\n",
    "trainData = DataLoader().load(train, batchSize=16)\n",
    "valData = DataLoader().load(val, batchSize=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator = Discriminator()\n",
    "extractor = buildExtractor()\n",
    "generator = RRDBNet(blockNum=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*  It's a good idea to pretrain the generator model before the min-max game - Reference: https://arxiv.org/abs/1701.00160"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train for 300 steps, validate for 100 steps\n",
      "300/300 [==============================] - 51s 170ms/step - loss: 1.0749 - psnr: 13.9406 - ssim: 0.2909 - val_loss: 0.7625 - val_psnr: 15.3406 - val_ssim: 0.3599\n"
     ]
    }
   ],
   "source": [
    "# a simple custom loss function that combines MAE loss with VGG loss, as defined in the SRGAN paper\n",
    "def contentLoss(y_true, y_pred):\n",
    "    featurePred = extractor(y_pred)\n",
    "    feature = extractor(y_true)\n",
    "    mae = tf.reduce_mean(tfk.losses.mae(y_true, y_pred))\n",
    "    return 0.1*tf.reduce_mean(tfk.losses.mse(featurePred, feature)) + mae\n",
    "\n",
    "optimizer = tfk.optimizers.Adam(learning_rate=1e-3)\n",
    "generator.compile(loss=contentLoss, optimizer=optimizer, metrics=[psnr, ssim])\n",
    "# epoch is set to 1 for demonstration purpose. In practice I found 20 is a good number\n",
    "# When the model reaches PSNR=20/ssim=0.65, we can start the min-max game\n",
    "history = generator.fit(x=trainData, validation_data=valData, epochs=1, steps_per_epoch=300, validation_steps=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generative adverserial network training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f22b53e2d6f411681e4657286eaf7d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d485872accfe4f8fa51cda81ad53b5f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# training parameter. epoch is set to 1 for demonstration\n",
    "# please train the network utill it reaches snRatio ~= 22 \n",
    "PARAMS = dict(lrGenerator = 1e-4, \n",
    "              lrDiscriminator = 1e-4,\n",
    "              epochs = 1, \n",
    "              stepsPerEpoch = 500, \n",
    "              valSteps = 100)\n",
    "game = MinMaxGame(generator, discriminator, extractor)\n",
    "log, valLog = game.train(trainData, valData, PARAMS)\n",
    "# ideally peak signal noise ratio(snRation or psnr) should reach ~22"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the model\n",
    "Because I defined the model as inherited class of tf keras model, they cannot be safely serialized.  \n",
    "Therefore, please save the weights only and follow the instructions in tutorial 1 to reload the model  \n",
    "You can found my pretrained model in the *Pretrained* folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#generator.save_weights(YOUR_PATH), save_format='tf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
